{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36972001-cbed-46f6-8bed-e57feec3bbd4",
   "metadata": {},
   "source": [
    "# 使用领域（私有）数据微调 ChatGLM3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2386615-d1d6-40c9-a014-b2bce85838ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch built with:\n",
      "  - GCC 9.3\n",
      "  - C++ Version: 201703\n",
      "  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications\n",
      "  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)\n",
      "  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n",
      "  - LAPACK is enabled (usually provided by MKL)\n",
      "  - NNPACK is enabled\n",
      "  - CPU capability usage: AVX512\n",
      "  - CUDA Runtime 12.1\n",
      "  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90\n",
      "  - CuDNN 8.9.2\n",
      "  - Magma 2.6.1\n",
      "  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.3.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, \n",
      " _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=14930MB, multi_processor_count=40)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print(torch.__config__.show(), torch.cuda.get_device_properties(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "280d5f7b-dada-49e1-81ee-d28a32900423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义全局变量和参数\n",
    "model_name_or_path = 'THUDM/chatglm3-6b'  # 模型ID或本地路径\n",
    "# train_data_path = 'data/zhouyi_dataset_handmade.csv'    # 训练数据路径\n",
    "train_data_path = 'data/zhouyi_dataset_20240118_152413.csv'    # 训练数据路径(批量生成数据集）\n",
    "eval_data_path = None                     # 验证数据路径，如果没有则设置为None\n",
    "seed = 8                                 # 随机种子\n",
    "max_input_length = 512                    # 输入的最大长度\n",
    "max_output_length = 1536                  # 输出的最大长度\n",
    "lora_rank = 16                             # LoRA秩\n",
    "lora_alpha = 32                           # LoRA alpha值\n",
    "lora_dropout = 0.05                       # LoRA Dropout率\n",
    "prompt_text = ''                          # 所有数据前的指令文本"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "461e2caf-4dd4-4d4d-a8f6-7bfd674d5754",
   "metadata": {},
   "source": [
    "## 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c4d167c-0f39-4b11-8fad-ec608c2b3c3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['content', 'summary'],\n",
      "        num_rows: 160\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"csv\", data_files=train_data_path)\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d02d0c6-fc78-429c-9193-1453a91aed81",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import ClassLabel, Sequence\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=10):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
    "            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4787d6f7-27ab-4822-b64e-2794bf52ec91",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>content</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>周易的\"讼卦讲述了什么？</td>\n",
       "      <td>在周易中，讼卦是一个极具深意的卦象。上卦为乾（天），下卦为坎（水），两者相背而行，代表天与水违行的状况，象征着事理乖舛和争讼之象。讼卦中有利可图，但必须警惕戒惧，事情中间吉利，但最终会有凶险。在卜卦时，利于会见贵族王公，但不利于涉水渡河。\\n\\n讼卦的核心哲学是：开始可能顺利，有所收获，但随后会遇到困难和挫折。因此，务必慎之又慎，不得固执已见，避免介入诉讼纠纷的争执之中。退让而不固执，求得化解，安于正理，可免除意外之灾。陷入争讼，即使获胜，最后还得失去，得不偿失。\\n\\n讼卦的经商指引是：和气生财，吃亏是福，切勿追求不义之财。在商业谈判中要坚持公正、公平、互利的原则，尽量避免发生冲突。\\n\\n对于决策，讼卦提醒我们，争强好胜，不安于现状，为改变命运和超越他人而奋斗。但缺乏持之以恒的毅力，容易得罪他人，带来诉讼之灾。因此，接受教训，引以为戒，可功成名就。\\n\\n讼卦所蕴含的智慧是：在面对争端和异见时，要善于退让和求和，坚守正道，谨慎处事，以避免不必要的冲突和损失。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>周易中的\"乾卦含义是什么？</td>\n",
       "      <td>\"乾卦\"\\nsummary: \"《易经》中的乾卦是六十四卦中的首卦，象征天，由六个阳爻组成，代表着刚健强劲的特性。其卦辞为“元、亨、利、贞”，预示着吉祥如意，同时也教导人们遵守天道的德行。乾卦所蕴含的核心哲学是：天道刚健，运行不已，君子观此卦象，从而以天为法，自强不息。\"\\n\\ncomment: \"在传统解卦中，乾卦预示着大吉大利，事业如日中天，但也提醒要警惕盛极必衰的道理。经商方面顺利发展，但要冷静分析形势，坚持商业道德。对于婚恋，尽管阳盛阴衰，但刚柔可相济，最终形成美满结果。总体而言，乾卦代表着充满活力和力量的时机，但也需要保持谦逊谨慎的态度，以应对可能出现的困难。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>坤卦的基本意义是什么？</td>\n",
       "      <td>坤卦是周易中的一卦，代表大地的形势，象征顺顺利利，顺应天道。它由两个坤卦叠加而成，全为阴卦，具有纯阴性质。在卜问中，坤卦预示着利于雌马的贞正之象，吉利的事情会发生。同时，出行需谨慎，起初可能会迷失方向，但最终会找到主人，带来吉利。前往西南会得到朋友，前往东北会失去朋友。在安定的状态下，能够得到吉祥。\\n\\n坤卦的形势平铺舒展，代表地道生育抚养万物，依循天时地利，象征以厚德载物。在决策上，应该顺从运势，守正静待，不宜急进，须以静制动为宜。事业发展上，要注重内心修养，广纳众意，以柔克刚为原则，力求和谐共生。经商求名上，同样要稳健行事，不宜冒险急进，在合作中共同完成事业。\\n\\n综上所述，坤卦意味着安稳、柔顺、温和、顺应天道，需要以厚德载物，以守正静待为策略，以诚信待人，收敛于己，力求温和和谐。\"\\n\\nformat:\"卦名:坤卦\\n解释:坤卦是周易中的一卦，代表大地的形势，象征顺顺利利，顺应天道。它由两个坤卦叠加而成，全为阴卦，具有纯阴性质。在卜问中，坤卦预示着利于雌马的贞正之象，吉利的事情会发生。同时，出行需谨慎，起初可能会迷失方向，但最终会找到主人，带来吉利。前往西南会得到朋友，前往东北会失去朋友。在安定的状态下，能够得到吉祥。\\n\\n坤卦的形势平铺舒展，代表地道生育抚养万物，依循天时地利，象征以厚德载物。在决策上，应该顺从运势，守正静待，不宜急进，须以静制动为宜。事业发展上，要注重内心修养，广纳众意，以柔克刚为原则，力求和谐共生。经商求名上，同样要稳健行事，不宜冒险急进，在合作中共同完成事业。\\n\\n综上所述，坤卦意味着安稳、柔顺、温和、顺应天道，需要以厚德载物，以守正静待为策略，以诚信待人，收敛于己，力求温和和谐。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>坤卦涉及哪些哲学思想？</td>\n",
       "      <td>坤卦是周易中的一卦，代表大地的形势，象征顺顺利利，顺应天道。它由两个坤卦叠加而成，全为阴卦，具有纯阴性质。在卜问中，坤卦预示着利于雌马的贞正之象，吉利的事情会发生。同时，出行需谨慎，起初可能会迷失方向，但最终会找到主人，带来吉利。前往西南会得到朋友，前往东北会失去朋友。在安定的状态下，能够得到吉祥。\\n\\n坤卦的形势平铺舒展，代表地道生育抚养万物，依循天时地利，象征以厚德载物。在决策上，应该顺从运势，守正静待，不宜急进，须以静制动为宜。事业发展上，要注重内心修养，广纳众意，以柔克刚为原则，力求和谐共生。经商求名上，同样要稳健行事，不宜冒险急进，在合作中共同完成事业。\\n\\n综上所述，坤卦意味着安稳、柔顺、温和、顺应天道，需要以厚德载物，以守正静待为策略，以诚信待人，收敛于己，力求温和和谐。\"\\n\\nformat:\"卦名:坤卦\\n解释:坤卦是周易中的一卦，代表大地的形势，象征顺顺利利，顺应天道。它由两个坤卦叠加而成，全为阴卦，具有纯阴性质。在卜问中，坤卦预示着利于雌马的贞正之象，吉利的事情会发生。同时，出行需谨慎，起初可能会迷失方向，但最终会找到主人，带来吉利。前往西南会得到朋友，前往东北会失去朋友。在安定的状态下，能够得到吉祥。\\n\\n坤卦的形势平铺舒展，代表地道生育抚养万物，依循天时地利，象征以厚德载物。在决策上，应该顺从运势，守正静待，不宜急进，须以静制动为宜。事业发展上，要注重内心修养，广纳众意，以柔克刚为原则，力求和谐共生。经商求名上，同样要稳健行事，不宜冒险急进，在合作中共同完成事业。\\n\\n综上所述，坤卦意味着安稳、柔顺、温和、顺应天道，需要以厚德载物，以守正静待为策略，以诚信待人，收敛于己，力求温和和谐。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>比卦在周易中怎样表达教育的概念？</td>\n",
       "      <td>比卦\"是周易卦象中的一枚卦，由下卦坤（地）上卦坎（水）组成，预示着吉利的变化。在卜筮时，再次占卜依然吉利，预示长期稳定无灾祸。然而，当不愿臣服的邦国未能前来朝贺时，将会带来危险。\\n\\n在《象辞》中，比卦被描述为地上有水的情景，反映了相亲相依相互依赖的意义。先王观此卦象，取法于水附大地，地纳江河之象，因此此卦被解释为建立万国，亲近诸侯。\\n\\n北宋易学家邵雍解释认为，比卦代表水在地面上流动，人际关系亲密和睦，各种事情无忧无虑。\\n\\n台湾国学大儒傅佩荣解释认为，比卦的时运是众人相贺，财运是善人相扶，家宅方面是长久美满，身体方面则需早求治心腹水肿。\\n\\n传统解卦认为，比卦预示着相亲相辅，宽宏无私，精诚团结的道理。在运势上，表示平顺，能得贵人提拔，事业宜速战速决，不宜过度迟疑。对于事业、经商、求名、婚恋等方面都有积极的影响，提醒人们待人宽厚、正直，主动热情，并谨慎选择朋友。\\n\\n总之，\"比卦</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(dataset[\"train\"], num_examples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b99c93ee-dfaf-4f38-b499-202accf20a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c680f5af-4c73-4f40-adfa-0e404e10d244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize_func 函数\n",
    "def tokenize_func(example, tokenizer, ignore_label_id=-100):\n",
    "    \"\"\"\n",
    "    对单个数据样本进行tokenize处理。\n",
    "\n",
    "    参数:\n",
    "    example (dict): 包含'content'和'summary'键的字典，代表训练数据的一个样本。\n",
    "    tokenizer (transformers.PreTrainedTokenizer): 用于tokenize文本的tokenizer。\n",
    "    ignore_label_id (int, optional): 在label中用于填充的忽略ID，默认为-100。\n",
    "\n",
    "    返回:\n",
    "    dict: 包含'tokenized_input_ids'和'labels'的字典，用于模型训练。\n",
    "    \"\"\"\n",
    "\n",
    "    # 构建问题文本\n",
    "    question = prompt_text + example['content']\n",
    "    if example.get('input', None) and example['input'].strip():\n",
    "        question += f'\\n{example[\"input\"]}'\n",
    "\n",
    "    # 构建答案文本\n",
    "    answer = example['summary']\n",
    "\n",
    "    # 对问题和答案文本进行tokenize处理\n",
    "    q_ids = tokenizer.encode(text=question, add_special_tokens=False)\n",
    "    a_ids = tokenizer.encode(text=answer, add_special_tokens=False)\n",
    "\n",
    "    # 如果tokenize后的长度超过最大长度限制，则进行截断\n",
    "    if len(q_ids) > max_input_length - 2:  # 保留空间给gmask和bos标记\n",
    "        q_ids = q_ids[:max_input_length - 2]\n",
    "    if len(a_ids) > max_output_length - 1:  # 保留空间给eos标记\n",
    "        a_ids = a_ids[:max_output_length - 1]\n",
    "\n",
    "    # 构建模型的输入格式\n",
    "    input_ids = tokenizer.build_inputs_with_special_tokens(q_ids, a_ids)\n",
    "    question_length = len(q_ids) + 2  # 加上gmask和bos标记\n",
    "\n",
    "    # 构建标签，对于问题部分的输入使用ignore_label_id进行填充\n",
    "    labels = [ignore_label_id] * question_length + input_ids[question_length:]\n",
    "\n",
    "    return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b99af28d-86b1-41c1-84da-2d2a4efc200f",
   "metadata": {},
   "outputs": [],
   "source": [
    "column_names = dataset['train'].column_names\n",
    "tokenized_dataset = dataset['train'].map(\n",
    "    lambda example: tokenize_func(example, tokenizer),\n",
    "    batched=False, \n",
    "    remove_columns=column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b749ba67-e890-4a42-a652-1eacf030ad47",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_dataset = tokenized_dataset.shuffle(seed=seed)\n",
    "tokenized_dataset = tokenized_dataset.flatten_indices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5fb9d942-e138-4ff4-a475-65d6e4f7d0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from typing import List, Dict, Optional\n",
    "\n",
    "# DataCollatorForChatGLM 类\n",
    "class DataCollatorForChatGLM:\n",
    "    \"\"\"\n",
    "    用于处理批量数据的DataCollator，尤其是在使用 ChatGLM 模型时。\n",
    "\n",
    "    该类负责将多个数据样本（tokenized input）合并为一个批量，并在必要时进行填充(padding)。\n",
    "\n",
    "    属性:\n",
    "    pad_token_id (int): 用于填充(padding)的token ID。\n",
    "    max_length (int): 单个批量数据的最大长度限制。\n",
    "    ignore_label_id (int): 在标签中用于填充的ID。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pad_token_id: int, max_length: int = 2048, ignore_label_id: int = -100):\n",
    "        \"\"\"\n",
    "        初始化DataCollator。\n",
    "\n",
    "        参数:\n",
    "        pad_token_id (int): 用于填充(padding)的token ID。\n",
    "        max_length (int): 单个批量数据的最大长度限制。\n",
    "        ignore_label_id (int): 在标签中用于填充的ID，默认为-100。\n",
    "        \"\"\"\n",
    "        self.pad_token_id = pad_token_id\n",
    "        self.ignore_label_id = ignore_label_id\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"\n",
    "        处理批量数据。\n",
    "\n",
    "        参数:\n",
    "        batch_data (List[Dict[str, List]]): 包含多个样本的字典列表。\n",
    "\n",
    "        返回:\n",
    "        Dict[str, torch.Tensor]: 包含处理后的批量数据的字典。\n",
    "        \"\"\"\n",
    "        # 计算批量中每个样本的长度\n",
    "        len_list = [len(d['input_ids']) for d in batch_data]\n",
    "        batch_max_len = max(len_list)  # 找到最长的样本长度\n",
    "\n",
    "        input_ids, labels = [], []\n",
    "        for len_of_d, d in sorted(zip(len_list, batch_data), key=lambda x: -x[0]):\n",
    "            pad_len = batch_max_len - len_of_d  # 计算需要填充的长度\n",
    "            # 添加填充，并确保数据长度不超过最大长度限制\n",
    "            ids = d['input_ids'] + [self.pad_token_id] * pad_len\n",
    "            label = d['labels'] + [self.ignore_label_id] * pad_len\n",
    "            if batch_max_len > self.max_length:\n",
    "                ids = ids[:self.max_length]\n",
    "                label = label[:self.max_length]\n",
    "            input_ids.append(torch.LongTensor(ids))\n",
    "            labels.append(torch.LongTensor(label))\n",
    "\n",
    "        # 将处理后的数据堆叠成一个tensor\n",
    "        input_ids = torch.stack(input_ids)\n",
    "        labels = torch.stack(labels)\n",
    "\n",
    "        return {'input_ids': input_ids, 'labels': labels}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "478eec3d-3b00-4d67-9bdc-88535c10ca27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备数据整理器\n",
    "data_collator = DataCollatorForChatGLM(pad_token_id=tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "951d4a53-ed02-4a2f-a6be-1b95052786ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d0ad8947-cfb8-47b0-90fd-5a8c93eed85a",
   "metadata": {},
   "source": [
    "## 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "102589d4-f07c-4952-abf2-42eed291c01d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ecb1f4b97d74a4ea393170c49d6063e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModel, BitsAndBytesConfig\n",
    "\n",
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n",
    "# 加载量化后模型\n",
    "model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                  quantization_config=q_config,\n",
    "                                  device_map='auto',\n",
    "                                  trust_remote_code=True)\n",
    "\n",
    "model.supports_gradient_checkpointing = True  \n",
    "model.gradient_checkpointing_enable()\n",
    "model.enable_input_require_grads()\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "28ce4c2d-d2a3-4e0a-a6a8-b66c8901b6a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.\n"
     ]
    }
   ],
   "source": [
    "from peft import TaskType, LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING\n",
    "\n",
    "kbit_model = prepare_model_for_kbit_training(model)\n",
    "target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['chatglm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4e0d219c-07a6-42b5-94dd-ac0a27864dcc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['query_key_value']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5f5c327f-c34d-4c22-a72f-7a40757514b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_config = LoraConfig(\n",
    "    target_modules=target_modules,\n",
    "    r=lora_rank,\n",
    "    lora_alpha=lora_alpha,\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias='none',\n",
    "    inference_mode=False,\n",
    "    task_type=TaskType.CAUSAL_LM\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8c90d7ca-de06-416c-971d-b416e0f52ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 3,899,392 || all params: 6,247,483,392 || trainable%: 0.06241540401681151\n"
     ]
    }
   ],
   "source": [
    "qlora_model = get_peft_model(kbit_model, lora_config)\n",
    "qlora_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f3c6c2-28b5-4d2d-b611-a4c3bf90f133",
   "metadata": {},
   "source": [
    "### QLoRA 微调模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1f4233fa-adee-4fa2-9893-bb149636208b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_epochs = 3\n",
    "output_dir = f\"models/{model_name_or_path}-epoch{train_epochs}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4f07e78f-8ff8-439a-8728-5b18b6a7626d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=output_dir,                            # 输出目录\n",
    "    per_device_train_batch_size=8,                     # 每个设备的训练批量大小\n",
    "    gradient_accumulation_steps=1,                     # 梯度累积步数\n",
    "    learning_rate=1e-3,                                # 学习率\n",
    "    num_train_epochs=train_epochs,                     # 训练轮数\n",
    "    lr_scheduler_type=\"linear\",                        # 学习率调度器类型\n",
    "    warmup_ratio=0.1,                                  # 预热比例\n",
    "    logging_steps=1,                                 # 日志记录步数\n",
    "    save_strategy=\"steps\",                             # 模型保存策略\n",
    "    save_steps=10,                                    # 模型保存步数\n",
    "    optim=\"adamw_torch\",                               # 优化器类型\n",
    "    fp16=True,                                        # 是否使用混合精度训练\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "12d86b61-0cf3-4b87-b866-3e81f58ad064",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "        model=qlora_model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_dataset,\n",
    "        data_collator=data_collator\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2bc298ef-3c6d-45cb-98d8-ccac43d386d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='60' max='60' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [60/60 12:23, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.594100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>4.049100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.091200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.381700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.547800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>2.610200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.657900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.163900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>2.277500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>2.295200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>1.855400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>1.883800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>1.551500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.986800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.991300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.529100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.616300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.552400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.494700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.279100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.253600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.197100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.125400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.123500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.082500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.046400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.042600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.021200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.023800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.021400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.025900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.016200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.016900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.012700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.010200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.009700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.013400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.011100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.008600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.006800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.005300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.005900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.007700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.004800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.004500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.005100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.004700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.004500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.004500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.004300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.004500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.004100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.003300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.003200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.002600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.003900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.003200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=60, training_loss=0.6928082378969217, metrics={'train_runtime': 756.4767, 'train_samples_per_second': 0.635, 'train_steps_per_second': 0.079, 'total_flos': 9433079239507968.0, 'train_loss': 0.6928082378969217, 'epoch': 3.0})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cc1156a8-a6c4-42f7-bbaf-a6736d51bd0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.model.save_pretrained(output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1f764a-bbde-4862-9665-e5c0d67c640a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
